# -*- coding: utf-8 -*-
# Copyright © 2007-2008 Stockholm TreeAligner Project
# Author: Torsten Marek <shlomme@gmx.net>
# Licensed under the GNU GPLv2
from __future__ import with_statement

# TODO: convert traces
import sys
import os
import tempfile
from optparse import OptionParser
from itertools import count

from nltk_contrib.tiger.utils.etree_xml import ET
from nltk_contrib.tiger import open_corpus_volatile, open_corpus

import functools

class _E(object):
    """Minimal element factory: ``E.tag(*children, **attrib)`` builds an element.

    The element is created with ``ET.Element(tag, attrib)``. Each positional
    child is handled by type:

    * dict    -- merged into the element's attributes;
    * string  -- appended to the element's text, or to the tail of the last
                 child element if a child element was already added;
    * element -- appended as a child.

    Any other child type raises ``TypeError``.
    """

    def __call__(self, tag, *children, **attrib):
        elem = ET.Element(tag, attrib)
        for item in children:
            if isinstance(item, dict):
                elem.attrib.update(item)
            elif isinstance(item, basestring):
                if len(elem):
                    # mixed content: text following a child is that child's tail
                    elem[-1].tail = (elem[-1].tail or "") + item
                else:
                    elem.text = (elem.text or "") + item
            elif ET.iselement(item):
                elem.append(item)
            else:
                # wrap in a 1-tuple so tuple arguments are reported, not mis-formatted
                raise TypeError("bad argument: %r" % (item, ))
        return elem

    def __getattr__(self, tag):
        # Only reached for attributes not found the normal way. Dunder names are
        # refused so that protocol probes such as getattr(obj, "__deepcopy__")
        # in the copy module do not receive a bogus element factory.
        if tag.startswith("__") and tag.endswith("__"):
            raise AttributeError(tag)
        return functools.partial(self, tag)

# Module-wide factory object: E.corpus(...), E.graph(...), ... build ET elements.
E = _E()

class Graph(object):
    """Collects the nodes of one sentence and serialises them as a TIGER-XML ``<s>``.

    Node ids have the form ``s<graph_id>_<n>``; terminals are numbered from 1,
    nonterminals from 500.
    """
    DEFAULT_LABEL = "--"

    def __init__(self, graph_id):
        self._id = graph_id
        self._get_t_id = count(1).next
        self._get_nt_id = count(500).next

        self._terminals = ET.Element("terminals")
        self._nonterminals = ET.Element("nonterminals")

    def add_terminal(self, word, pos):
        """Add a ``<t>`` node; return ``((DEFAULT_LABEL, node_id), pos)``."""
        node_id = "s%i_%i" % (self._id, self._get_t_id())

        ET.SubElement(self._terminals, "t", id=node_id, word=word, pos=pos)
        return (self.DEFAULT_LABEL, node_id), pos

    def _split_category(self, cat):
        """Split a Penn Treebank node label into ``(category, edge_label)``.

        ``NP-SBJ-1=2`` gives ``("NP", "SBJ")``. The category is the text before
        the first ``-`` with any ``=N`` index removed. The edge label is the first
        function tag unless it is empty or purely numeric (``NP-1`` gives
        ``("NP", DEFAULT_LABEL)``). A label with no usable category part, such
        as ``-NONE-``, is returned whole with DEFAULT_LABEL.
        """
        parts = cat.split("-")
        real_cat = parts[0].split("=")[0]
        if not real_cat:
            return cat, self.DEFAULT_LABEL
        lbl = self.DEFAULT_LABEL
        if len(parts) > 1:
            func = parts[1].split("=")[0]
            if func and not func.isdigit():
                lbl = func
        return real_cat, lbl

    def add_nonterminal(self, cat, children):
        """Add an ``<nt>`` node with one ``<edge>`` per ``(label, node_id)`` child.

        Returns ``((edge_label, node_id), category)``, where both parts are
        derived from ``cat`` by ``_split_category``.
        """
        node_id = "s%i_%i" % (self._id, self._get_nt_id())

        real_cat, lbl = self._split_category(cat)
        nt = ET.SubElement(self._nonterminals, "nt", id=node_id, cat=real_cat)
        for child_lbl, child_id in children:
            ET.SubElement(nt, "edge", idref=child_id, label=child_lbl)
        return (lbl, node_id), real_cat

    def get_xml(self, root_id):
        """Return the ``<s id=...><graph root=...>`` element for this sentence."""
        return E.s(
            E.graph(
                self._terminals,
                self._nonterminals,
                root=root_id),
            id="s%i" % (self._id))

class TreebankConverter(object):
    """Accumulates parsed Penn Treebank sentences and writes one TIGER-XML corpus.

    Every category, POS tag and edge label seen while converting is recorded,
    so that the corpus header can declare them.
    """

    def __init__(self):
        self._cats = set()
        self._pos_tags = set()
        self._edge_labels = set()
        # not populated by this class: secondary edges are not converted
        self._secedge_labels = set()

        self._graphs = []
        self._next_s_id = count(1).next

    @staticmethod
    def _node_label(tree):
        """Return the node label of an NLTK tree (``label()`` or legacy ``.node``)."""
        label = getattr(tree, "label", None)
        if callable(label):
            return label()
        return tree.node

    def _get_children(self, penn_tree, graph):
        """Add the children of ``penn_tree`` to ``graph``, depth first.

        A child whose only child is a string is a preterminal and becomes a
        terminal node. Any other child becomes a nonterminal, after its own
        children have been added. Returns the ``(label, node_id)`` edges of the
        direct children, in order.
        """
        children = []
        for child in penn_tree:
            if len(child) == 1 and isinstance(child[0], basestring):
                edge, tag = graph.add_terminal(child[0], self._node_label(child))
                self._pos_tags.add(tag)
            else:
                edge, cat = graph.add_nonterminal(
                    self._node_label(child), self._get_children(child, graph))
                self._cats.add(cat)
            self._edge_labels.add(edge[0])
            children.append(edge)

        return children

    def add_sentence(self, sentence):
        """Convert one parsed sentence into a new, consecutively numbered graph."""
        graph = Graph(self._next_s_id())
        # wrap the sentence in a list so that its root becomes a node as well
        root_id = self._get_children([sentence], graph)[0][1]
        self._graphs.append(graph.get_xml(root_id))

    def _append_values(self, parent, names):
        """Append one ``<value name=...>`` per name to ``parent`` and return it."""
        # sets have no stable iteration order; sorting keeps the header reproducible
        for name in sorted(names):
            parent.append(E.value("", name=name))
        return parent

    def _edgelabels(self):
        return self._append_values(E.edgelabel(), self._edge_labels)

    def _list_feature(self, name, domain, feature_set):
        return self._append_values(E.feature(name=name, domain=domain), feature_set)

    def _get_annotations(self):
        return E.annotation(
            E.feature(name="word", domain="T"),
            self._list_feature("pos", "T", self._pos_tags),
            self._list_feature("cat", "NT", self._cats),
            self._edgelabels(),
            E.secedgelabel())

    def _get_header(self):
        return E.head(
            E.meta(
                E.name("Penn Treebank Sampler"),
                E.author("Autogenerated"),
                E.date("today"),
                E.description("Penn Treebank Sampler distributed with NLTK converted to TIGER-XML"),
                E.format("Penn-Treebank Format"),
                E.history("Autogenerated")),
            self._get_annotations())

    def _get_body(self):
        body = E.body()
        body[:] = self._graphs
        return body

    def write(self, filename):
        """Write the whole corpus as UTF-8 TIGER-XML.

        ``filename`` is passed unchanged to ``ElementTree.write``. Callers in
        this module pass either a path or an open temporary file.
        """
        corpus = E.corpus(
            self._get_header(),
            self._get_body(),
            id="penn-sampler")
        ET.ElementTree(corpus).write(filename, encoding="UTF-8")

def convert_wsj(file_obj):
    """Convert the Penn Treebank sample bundled with NLTK to TIGER-XML.

    Writes a progress note to stderr. ``file_obj`` is handed unchanged to
    ``TreebankConverter.write``; ``demo`` passes either a path or an open
    temporary file.
    """
    from nltk.corpus import treebank

    sys.stderr.write("Converting Penn Treebank sampler...\n")
    converter = TreebankConverter()
    add = converter.add_sentence
    for parsed in treebank.parsed_sents():
        add(parsed)
    converter.write(file_obj)
    

def demo():
    """Convert the NLTK Penn Treebank sample to TIGER-XML and run one query on it.

    With ``-c FILE`` the converted corpus is kept in FILE. Conversion only
    happens if FILE does not exist yet, and the corpus is opened with
    ``open_corpus``. Without ``-c`` a temporary file is used and loaded with
    ``open_corpus_volatile``. ``-q`` sets the query; the default is
    ``[cat="NP"] > [cat="PP"]``. For each matching graph, its id and its
    number of matches are printed.
    """
    op = OptionParser()
    op.add_option("-c", "--corpus-file", help="If specified, the Penn sample will be stored/loaded in the given path.", metavar="FILE", action="store")
    op.add_option("-q", "--query", help="The query to be evaluated", default='[cat="NP"] > [cat="PP"]')

    options, _args = op.parse_args()

    if options.corpus_file:
        corpus_file = options.corpus_file
        if not os.path.exists(corpus_file):
            try:
                convert_wsj(corpus_file)
            except BaseException:
                # A half-written file would be taken for a finished corpus on
                # the next run (the exists() check above), so remove it.
                if os.path.exists(corpus_file):
                    os.remove(corpus_file)
                raise

        corpus = open_corpus("penn", corpus_file, veeroot=False)
    else:
        sys.stderr.write("Info: Use '-c' to keep converted corpus and skip conversion/indexing.\n")
        corpus_file = tempfile.TemporaryFile()
        try:
            convert_wsj(corpus_file)
            corpus_file.seek(0)
            corpus = open_corpus_volatile("penn", corpus_file, veeroot=False)
        finally:
            # close (and so delete) the temporary file even if loading fails
            corpus_file.close()

    print("Corpus size: %i graphs. " % (len(corpus), ))
    evaluator = corpus.get_query_evaluator()
    # Corpus is too small for parallel evaluation to speed up things
    evaluator.set_allow_parallel(False)
    print("Evaluating: %s" % (options.query, ))
    query = evaluator.prepare_query(options.query)
    for result in query.evaluate():
        print("Graph: %s, matches: %i" % (corpus.get_xml_graph_id(result[0]), len(result[1])))
    # for access to the graph objects, see nltk_contrib.tiger.corpus and nltk_contrib.tiger.graph


if __name__ == "__main__":
    demo()
